from llm.llm_wrapper import LLMWrapper
from utils.logger import Logger
import core.agent_prompt as AgentPrompt
from utils.json_utils import extract_json

from llm.message import (
    Message,
    MessageContent,
    ROLE_SYSTEM,
    ROLE_USER,
    ROLE_ASSISTANT,
    TYPE_SETTING,
    TYPE_CONTEXT,
    TYPE_CONTENT,
)

import os
import json
from typing import Dict

class DomainCategory:
    """Classify queries into a two-level domain taxonomy proposed by an LLM.

    The taxonomy lives in ``self.category_map`` with the shape
    ``{primary_category: {secondary_category: {'desc': desc}}}`` and is
    persisted as JSON at ``<config['cache_dir']>/category_map.json``.
    New secondary categories returned by the LLM are added to the in-memory
    map; ``_update_categories`` writes the map back to disk (it is also called
    from ``__del__`` on a best-effort basis).
    """

    _CATEGORY_FILE = 'category_map.json'

    def __init__(self, logger: "Logger", llm: "LLMWrapper", config=None):
        """Create the classifier and load any cached category map.

        Args:
            logger: Project logger; its ``info`` and ``log_exception`` methods
                are used.
            llm: LLM wrapper whose ``generate(messages)`` is called by
                ``classify``.
            config: Settings dict; must contain a non-empty ``'cache_dir'``.
                ``None`` is treated as an empty dict.

        Raises:
            ValueError: If ``config`` has no usable ``'cache_dir'``.
        """
        self.logger = logger
        self.llm = llm
        # A fresh dict per instance instead of a shared mutable default.
        self.config = {} if config is None else config

        self.category_map: Dict = self._init_categories()

    def _category_map_path(self):
        """Return the cache file path, creating the cache directory if needed.

        Raises:
            ValueError: If ``config['cache_dir']`` is missing or empty.
        """
        cache_dir = self.config.get('cache_dir')
        if not cache_dir:
            raise ValueError("Unknown cache dir")
        os.makedirs(cache_dir, exist_ok=True)
        return os.path.join(cache_dir, self._CATEGORY_FILE)

    def _init_categories(self):
        """Load the cached category map, or return ``{}`` if it is unavailable.

        A missing cache file yields ``{}``. A file that cannot be read or
        parsed, or whose top-level JSON value is not an object, is reported via
        ``logger.log_exception`` and also yields ``{}``.

        Raises:
            ValueError: If ``config['cache_dir']`` is missing or empty.
        """
        filepath = self._category_map_path()
        if not os.path.exists(filepath):
            return {}

        try:
            with open(filepath, "r", encoding="utf-8") as f:
                category_map = json.load(f)
            # The rest of the class calls .get() on the map, so only a dict works.
            if not isinstance(category_map, dict):
                raise ValueError(f"Expected a JSON object in {filepath}")
        except (OSError, ValueError) as e:
            # json.JSONDecodeError and UnicodeDecodeError subclass ValueError.
            self.logger.log_exception(e)
            return {}
        return category_map

    def _update_categories(self):
        """Write ``self.category_map`` to the cache file as JSON.

        The JSON is written to a sibling ``.tmp`` file and then moved over the
        real file with ``os.replace``, so an interrupted write does not leave a
        truncated cache file in place.

        Raises:
            ValueError: If ``config['cache_dir']`` is missing or empty.
        """
        filepath = self._category_map_path()
        tmp_path = filepath + '.tmp'
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.category_map, f)
        os.replace(tmp_path, filepath)

    def classify(self, query):
        """Ask the LLM to categorise ``query`` and return the chosen category.

        Returns:
            A dict with keys ``'primary_category'``, ``'secondary_category'``
            and ``'desc'``.

        Raises:
            ValueError: If the parsed LLM response does not contain a usable
                category (see ``_resolve_category``).
        """
        prompt = AgentPrompt.domain_classify_prompt(self.category_map, query)
        messages = [Message(ROLE_USER, [MessageContent(TYPE_CONTENT, prompt)])]
        response = self.llm.generate(messages)
        self.logger.info(f"Classify Response: {response}")
        return self._resolve_category(extract_json(response))

    def _resolve_category(self, data):
        """Merge a parsed LLM answer into ``self.category_map`` and describe it.

        ``data`` is expected to be a dict (presumably what ``extract_json``
        returns; confirm against that helper) with ``'primary_category'`` and
        ``'secondary_category'``, plus ``'new_category_desc'`` when the
        secondary category is not yet in the map. Unknown secondary categories
        are added in memory only; known ones keep their stored ``desc``.

        Returns:
            A dict with keys ``'primary_category'``, ``'secondary_category'``
            and ``'desc'``.

        Raises:
            ValueError: If ``data`` is not a dict, either category is missing or
                empty, or a new category lacks ``'new_category_desc'``.
        """
        if not isinstance(data, dict):
            raise ValueError("Invalid category")

        primary_category = data.get('primary_category')
        secondary_category = data.get('secondary_category')
        if not primary_category or not secondary_category:
            raise ValueError("Invalid category")

        primary_category_info: Dict = self.category_map.get(primary_category, {})

        if secondary_category in primary_category_info:
            desc = primary_category_info[secondary_category]['desc']
        else:
            # Validate before mutating so a bad response leaves the map untouched.
            if 'new_category_desc' not in data:
                raise ValueError("Missing new_category_desc for new category")
            desc = data['new_category_desc']
            primary_category_info[secondary_category] = {'desc': desc}
            # Required when primary_category was new and .get() returned a fresh {}.
            self.category_map[primary_category] = primary_category_info

        return {
            'primary_category': primary_category,
            'secondary_category': secondary_category,
            'desc': desc,
        }

    def __del__(self):
        """Best-effort save of the category map when the object is finalized.

        This is skipped when ``__init__`` failed before ``category_map`` was
        assigned. Errors are logged instead of raised, because an exception
        escaping a finalizer is only printed by the interpreter.
        """
        if not hasattr(self, 'category_map'):
            return
        try:
            self._update_categories()
        except Exception as e:
            logger = getattr(self, 'logger', None)
            if logger is not None:
                logger.log_exception(e)
